package edu.berkeley.nlp.util;

import edu.berkeley.nlp.math.SloppyMath;

import java.util.*;

/**
 * Static helper routines for building, normalizing, sorting and transforming
 * {@link Counter} and {@link CounterMap} instances.
 *
 * @author Dan Klein
 */
public class Counters {
	/**
	 * Builds a new counter in which every count of {@code counter} is divided
	 * by {@code counter.totalCount()}. Results are written to a fresh counter.
	 * <p>
	 * No guard is applied to the total: if it is zero, the stored values are
	 * whatever floating-point division by zero yields (NaN or an infinity).
	 *
	 * @param counter the counts to normalize
	 * @param <E> key type
	 * @return a new counter holding {@code count / total} for every key
	 */
	public static <E> Counter<E> normalize(Counter<E> counter) {
		Counter<E> normalizedCounter = new Counter<E>();
		double total = counter.totalCount();
		for (E key : counter.keySet()) {
			normalizedCounter.setCount(key, counter.getCount(key) / total);
		}
		return normalizedCounter;
	}

	/**
	 * Counts the elements of {@code iterable}, incrementing the count of each
	 * element by 1.0 every time it is encountered.
	 *
	 * @param iterable the elements to count
	 * @param <T> element type
	 * @return a new counter with one increment of 1.0 per element seen
	 */
	public static <T> Counter<T> counterFromCollection(Iterable<T> iterable) {
		Counter<T> counts = new Counter<T>();
		for (T t : iterable) {
			counts.incrementCount(t, 1.0);
		}
		return counts;
	}

	/**
	 * Accumulates counts over a sequence of collections by calling
	 * {@code incrementAll(coll, 1.0)} on a single counter for each collection.
	 * <p>
	 * The type parameter {@code C} is not used by the signature or the body;
	 * it is retained only so that existing callers supplying explicit type
	 * arguments keep compiling.
	 *
	 * @param iterable the collections whose elements are counted
	 * @param <E> element type
	 * @param <C> unused
	 * @return a new counter accumulated over all collections
	 */
	public static <E, C extends Iterable<?>> Counter<E> counterFromData(
			Iterable<? extends Collection<E>> iterable) {
		Counter<E> counts = new Counter<E>();
		for (Collection<E> coll : iterable) {
			counts.incrementAll(coll, 1.0);
		}
		return counts;
	}

	/**
	 * Applies {@link #normalize(Counter)} separately to the sub-counter of
	 * every key of {@code counterMap} and collects the results in a new
	 * counter map.
	 *
	 * @param counterMap the counts to normalize per key
	 * @param <K> outer key type
	 * @param <V> inner key type
	 * @return a new counter map holding the per-key normalized counts
	 */
	public static <K, V> CounterMap<K, V> conditionalNormalize(
			CounterMap<K, V> counterMap) {
		CounterMap<K, V> normalizedCounterMap = new CounterMap<K, V>();
		for (K key : counterMap.keySet()) {
			Counter<V> normalizedSubCounter = normalize(counterMap
					.getCounter(key));
			for (V value : normalizedSubCounter.keySet()) {
				double count = normalizedSubCounter.getCount(value);
				normalizedCounterMap.setCount(key, value, count);
			}
		}
		return normalizedCounterMap;
	}

	/**
	 * Computes the Euclidean (L2) norm of the counts: the square root of the
	 * sum of the squared entry values.
	 *
	 * @param counts the counts to measure
	 * @param <K> key type
	 * @return the L2 norm; 0.0 when there are no entries
	 */
	public static <K> double l2Norm(Counter<K> counts) {
		double sum = 0.0;
		for (Map.Entry<K, Double> entry : counts.getEntrySet()) {
			double count = entry.getValue();
			sum += count * count;
		}
		return Math.sqrt(sum);
	}

	/**
	 * Builds a new counter in which every count of {@code counts} is divided
	 * by the L2 norm of {@code counts}.
	 *
	 * @param counts the counts to normalize
	 * @param <K> key type
	 * @return a new counter with the scaled counts, or an empty counter when
	 *         the norm is exactly 0.0
	 */
	public static <K> Counter<K> l2Normalize(Counter<K> counts) {
		Counter<K> normalizedCounts = new Counter<K>();
		double norm = l2Norm(counts);
		if (norm == 0.0) {
			// Nothing can be scaled; return the empty counter rather than
			// dividing by zero.
			return normalizedCounts;
		}
		for (K key : counts.keySet()) {
			double count = counts.getCount(key);
			normalizedCounts.setCount(key, count / norm);
		}
		return normalizedCounts;
	}

	/**
	 * Returns the keys of {@code counts} sorted by descending count.
	 * <p>
	 * Counts that compare equal with {@code ==} (including two identical
	 * infinities, and 0.0 versus -0.0) are treated as ties. All other pairs
	 * are ordered with {@link Double#compare(double, double)}, so NaN counts
	 * sort as the largest value and therefore come first. Comparing by the
	 * sign of a subtraction is deliberately avoided: the difference of two
	 * equal infinities, or of anything and NaN, is NaN, which would make the
	 * comparator inconsistent and can cause the sort to fail.
	 *
	 * @param counts the counter supplying both the keys and the sort values
	 * @param <L> key type
	 * @return a new list of the keys, highest count first
	 */
	public static <L> List<L> sortedKeys(final Counter<L> counts) {
		List<L> keys = new ArrayList<L>(counts.keySet());
		Collections.sort(keys, new Comparator<L>() {

			public int compare(L arg0, L arg1) {
				double first = counts.getCount(arg0);
				double second = counts.getCount(arg1);
				if (first == second) {
					return 0;
				}
				// Arguments are swapped to obtain descending order.
				return Double.compare(second, first);
			}

		});
		return keys;
	}

	/**
	 * Builds a new counter holding {@code Math.exp(value)} for every entry of
	 * {@code counts}.
	 *
	 * @param counts the values to exponentiate
	 * @param <K> key type
	 * @return a new counter with the exponentiated values
	 */
	public static <K> Counter<K> exponentiate(Counter<K> counts) {
		Counter<K> exponentiated = new Counter<K>();
		for (Map.Entry<K, Double> entry : counts.entrySet()) {
			exponentiated.setCount(entry.getKey(), Math.exp(entry.getValue()));
		}
		return exponentiated;
	}

	/**
	 * Replaces every value of {@code counts} with {@code Math.exp(value)}.
	 *
	 * @param counts the counter to modify in place
	 * @param <K> key type
	 */
	public static <K> void exponentiateInPlace(Counter<K> counts) {
		for (Map.Entry<K, Double> entry : counts.entrySet()) {
			entry.setValue(Math.exp(entry.getValue()));
		}
		// Values were written through the entry view, so flag the counter
		// the same way makeProbsFromLogScoresInPlace does (presumably so
		// that cached aggregates are recomputed).
		counts.setDirty(true);
	}

	/**
	 * Replaces every value of {@code counts} with {@code Math.log(value)}.
	 *
	 * @param counts the counter to modify in place
	 * @param <K> key type
	 */
	public static <K> void logInPlace(Counter<K> counts) {
		for (Map.Entry<K, Double> entry : counts.entrySet()) {
			entry.setValue(Math.log(entry.getValue()));
		}
		// Values were written through the entry view, so flag the counter
		// the same way makeProbsFromLogScoresInPlace does (presumably so
		// that cached aggregates are recomputed).
		counts.setDirty(true);
	}

	/**
	 * Replaces every value {@code v} of {@code logScores} with
	 * {@code Math.exp(v - logSum)}, where {@code logSum} is the result of
	 * {@code SloppyMath.logAdd(logScores)} (presumably the log of the summed
	 * exponentiated scores, which would make the results sum to one; confirm
	 * against SloppyMath). The counter is marked dirty afterwards because
	 * the values are written through the entry view.
	 *
	 * @param logScores the log-domain scores to convert in place
	 * @param <K> key type
	 */
	public static <K> void makeProbsFromLogScoresInPlace(Counter<K> logScores) {
		double logSum = SloppyMath.logAdd(logScores);
		for (Map.Entry<K, Double> entry : logScores.entrySet()) {
			double logScore = entry.getValue();
			double prob = Math.exp(logScore - logSum);
			entry.setValue(prob);
		}
		logScores.setDirty(true);
	}
}
